{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fDZ1cweXj8s3"
      },
      "source": [
        "Sora from OpenAI, Stable Video Diffusion from Stability AI, and many other text-to-video models that have come out or will appear in the future are among the most popular AI trends in 2024, following large language models (LLMs). In this blog, we will build a **small scale text-to-video model from scratch**. We will input a text prompt, and our trained model will generate a video based on that prompt. This blog will cover everything from understanding the theoretical concepts to coding the entire architecture and generating the final result.\n",
        "\n",
        "Since I don’t have a fancy GPU, I’ve coded the small-scale architecture. Here’s a comparison of the time required to train the model on different processors:\n",
        "\n",
        "| Training Videos | Epochs | CPU      | GPU A10 | GPU T4    |\n",
        "|---------------|--------|----------|---------|-----------|\n",
        "| 10K           | 30     | more than 3 hr    | 1 hr    | 1 hr 42m  |\n",
        "| 30K           | 30     | more than 6 hr    | 1 hr 30 | 2 hr 30   |\n",
        "| 100K          | 30     | -        | 3-4 hr  | 5-6 hr    |\n",
        "\n",
        "Running on a CPU will obviously take much longer to train the model. If you need to quickly test changes in the code and see results, CPU is not the best choice. I recommend using a T4 GPU from [Colab](https://colab.research.google.com/) or [Kaggle](https://kaggle.com/) for more efficient and faster training.\n",
        "\n",
        "Here is the blog link which guides you on how to create Stable Diffusion from scratch:\n",
        "[Coding Stable Diffusion from Scratch](https://levelup.gitconnected.com/building-stable-diffusion-from-scratch-using-python-f3ebc8c42da3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_BnTekhJj8s6"
      },
      "source": [
        "## What We’re Building\n",
        "\n",
        "We will follow a similar approach to traditional machine learning or deep learning models that train on a dataset and are then tested on unseen data. In the context of text-to-video, let’s say we have a training dataset of 100K videos of dogs fetching balls and cats chasing mice. We will train our model to generate videos of a cat fetching a ball or a dog chasing a mouse.\n",
        "\n",
        "<img src=\"https://cdn-images-1.medium.com/max/3840/1*6h3oJzGEH0xrER2Tv8M7KQ.gif\" width=\"900\">\n",
        "\n",
        "Although such training datasets are easily available on the internet, the required computational power is extremely high. Therefore, we will work with a video dataset of moving objects generated from Python code.\n",
        "\n",
        "We will use the GAN (Generative Adversarial Networks) architecture to create our model instead of the diffusion model that OpenAI Sora uses. I attempted to use the diffusion model, but it crashed due to memory requirements, which is beyond my capacity. GANs, on the other hand, are easier and quicker to train and test."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8zSgLUkmj8s6"
      },
      "source": [
        "## Prerequisites\n",
        "\n",
        "We will be using OOP (Object-Oriented Programming), so you must have a basic understanding of it along with neural networks. Knowledge of GANs (Generative Adversarial Networks) is not mandatory, as we will be covering their architecture here.\n",
        "\n",
        "| Topic | Link |\n",
        "| ---- | ---- |\n",
        "| OOP | [Video Link](https://www.youtube.com/watch?v=q2SGW2VgwAM) |\n",
        "| Neural Networks Theory |  [Video Link](https://www.youtube.com/watch?v=Jy4wM2X21u0) |\n",
        "| GAN Architecture |  [Video Link](https://www.youtube.com/watch?v=TpMIssRdhco) |\n",
        "| Python basics |  [Video Link](https://www.youtube.com/watch?v=eWRfhZUzrAc) |\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FNm7BMx4j8s6"
      },
      "source": [
        "# Understanding the GAN Architecture\n",
        "\n",
        "Understanding GAN architecture is important because much of our architecture depends on it. Let’s explore what it is, its components, and more.\n",
        "\n",
        "### what is **GAN**?\n",
        "\n",
        "Generative Adversarial Network (GAN) is a deep learning model where two neural networks compete: one creates new data (like images or music) from a given dataset, and the other tries to tell if the data is real or fake. This process continues until the generated data is indistinguishable from the original.\n",
        "\n",
        "### Real-World Application\n",
        "\n",
        " 1. **Generate Images**: GANs create realistic images from text prompts or modify existing images, such as enhancing resolution or adding color to black-and-white photos.\n",
        "\n",
        " 2. **Data Augmentation**: They generate synthetic data to train other machine learning models, such as creating fraudulent transaction data for fraud detection systems.\n",
        "\n",
        " 3. **Complete Missing Information**: GANs can fill in missing data, like generating sub-surface images from terrain maps for energy applications.\n",
        "\n",
        " 4. **Generate 3D Models**: They convert 2D images into 3D models, useful in fields like healthcare for creating realistic organ images for surgical planning.\n",
        "\n",
        "### How does a GAN work?\n",
        "\n",
        "It consists of two deep neural networks: the **generator** and the **discriminator**. These networks train together in an adversarial setup, where one generates new data and the other evaluates if the data is real or fake.\n",
        "\n",
        "Here’s a simplified overview of how GAN works:\n",
        "\n",
        " 1. **Training Set Analysis**: The generator analyzes the training set to identify data attributes, while the discriminator independently analyzes the same data to learn its attributes.\n",
        "\n",
        " 2. **Data Modification**: The generator adds noise (random changes) to some attributes of the data.\n",
        "\n",
        " 3. **Data Passing**: The modified data is then passed to the discriminator.\n",
        "\n",
        " 4. **Probability Calculation**: The discriminator calculates the probability that the generated data is from the original dataset.\n",
        "\n",
        " 5. **Feedback Loop**: The discriminator provides feedback to the generator, guiding it to reduce random noise in the next cycle.\n",
        "\n",
        " 6. **Adversarial Training**: The generator tries to maximize the discriminator’s mistakes, while the discriminator tries to minimize its own errors. Through many training iterations, both networks improve and evolve.\n",
        "\n",
        " 7. **Equilibrium State**: Training continues until the discriminator can no longer distinguish between real and synthesized data, indicating that the generator has successfully learned to produce realistic data. At this point, the training process is complete.\n",
        "\n",
        "![From [AWS Guide](https://aws.amazon.com/what-is/gan/)](https://cdn-images-1.medium.com/max/2796/1*2HsK-UFPRvCdAmQyS3Ol1Q.jpeg)\n",
        "\n",
        "### GAN training example\n",
        "\n",
        "Let’s explain the GAN model with an example of image-to-image translation, focusing on modifying a human face.\n",
        "\n",
        " 1. **Input Image**: The input is a real image of a human face.\n",
        "\n",
        " 2. **Attribute Modification**: The generator modifies attributes of the face, like adding sunglasses to the eyes.\n",
        "\n",
        " 3. **Generated Images**: The generator creates a set of images with sunglasses added.\n",
        "\n",
        " 4. **Discriminator’s Task**: The discriminator receives a mix of real images (people with sunglasses) and generated images (faces where sunglasses were added).\n",
        "\n",
        " 5. **Evaluation**: The discriminator tries to differentiate between real and generated images.\n",
        "\n",
        " 6. **Feedback Loop**: If the discriminator correctly identifies fake images, the generator adjusts its parameters to produce more convincing images. If the generator successfully fools the discriminator, the discriminator updates its parameters to improve its detection.\n",
        "\n",
        "Through this adversarial process, both networks continuously improve. The generator gets better at creating realistic images, and the discriminator gets better at identifying fakes until equilibrium is reached, where the discriminator can no longer tell the difference between real and generated images. At this point, the GAN has successfully learned to produce realistic modifications."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GFFrVzR9j8s7"
      },
      "source": [
        "## Importing Libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "_vABkDvGj8s7"
      },
      "outputs": [],
      "source": [
        "# Operating System module for interacting with the operating system\n",
        "import os\n",
        "\n",
        "# Module for generating random numbers\n",
        "import random\n",
        "\n",
        "# Module for numerical operations\n",
        "import numpy as np\n",
        "\n",
        "# OpenCV library for image processing\n",
        "import cv2\n",
        "\n",
        "# Python Imaging Library for image processing\n",
        "from PIL import Image, ImageDraw, ImageFont\n",
        "\n",
        "# PyTorch library for deep learning\n",
        "import torch\n",
        "\n",
        "# Dataset class for creating custom datasets in PyTorch\n",
        "from torch.utils.data import Dataset\n",
        "\n",
        "# Module for image transformations\n",
        "import torchvision.transforms as transforms\n",
        "\n",
        "# Neural network module in PyTorch\n",
        "import torch.nn as nn\n",
        "\n",
        "# Optimization algorithms in PyTorch\n",
        "import torch.optim as optim\n",
        "\n",
        "# Function for padding sequences in PyTorch\n",
        "from torch.nn.utils.rnn import pad_sequence\n",
        "\n",
        "# Function for saving images in PyTorch\n",
        "from torchvision.utils import save_image\n",
        "\n",
        "# Module for plotting graphs and images\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Module for displaying rich content in IPython environments\n",
        "from IPython.display import clear_output, display, HTML\n",
        "\n",
        "# Module for encoding and decoding binary data to text\n",
        "import base64"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mg5sx3Qnj8s8"
      },
      "source": [
        "## Coding the Training Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8Lh9GLYXj8s8",
        "outputId": "0285032c-b1cd-4ab7-ba35-311484b27b52"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset generation complete.\n"
          ]
        }
      ],
      "source": [
        "# Create the directory for the training dataset\n",
        "os.makedirs('training_dataset', exist_ok=True)\n",
        "\n",
        "# Define the number of videos to generate for the dataset\n",
        "num_videos = 30000\n",
        "\n",
        "# Define the number of frames per video (1 Second Video)\n",
        "frames_per_video = 10\n",
        "\n",
        "# Define the size of each image in the dataset\n",
        "img_size = (64, 64)\n",
        "\n",
        "# Define the size of the shapes (Circle)\n",
        "shape_size = 10\n",
        "\n",
        "# Define text prompts and corresponding movements for circles\n",
        "prompts_and_movements = [\n",
        "    (\"circle moving down\", \"circle\", \"down\"),  # Move circle downward\n",
        "    (\"circle moving left\", \"circle\", \"left\"),  # Move circle leftward\n",
        "    (\"circle moving right\", \"circle\", \"right\"),  # Move circle rightward\n",
        "    (\"circle moving diagonally up-right\", \"circle\", \"diagonal_up_right\"),  # Move circle diagonally up-right\n",
        "    (\"circle moving diagonally down-left\", \"circle\", \"diagonal_down_left\"),  # Move circle diagonally down-left\n",
        "    (\"circle moving diagonally up-left\", \"circle\", \"diagonal_up_left\"),  # Move circle diagonally up-left\n",
        "    (\"circle moving diagonally down-right\", \"circle\", \"diagonal_down_right\"),  # Move circle diagonally down-right\n",
        "    (\"circle rotating clockwise\", \"circle\", \"rotate_clockwise\"),  # Rotate circle clockwise\n",
        "    (\"circle rotating counter-clockwise\", \"circle\", \"rotate_counter_clockwise\"),  # Rotate circle counter-clockwise\n",
        "    (\"circle bouncing vertically\", \"circle\", \"bounce_vertical\"),  # Bounce circle vertically\n",
        "    (\"circle bouncing horizontally\", \"circle\", \"bounce_horizontal\"),  # Bounce circle horizontally\n",
        "    (\"circle zigzagging vertically\", \"circle\", \"zigzag_vertical\"),  # Zigzag circle vertically\n",
        "    (\"circle zigzagging horizontally\", \"circle\", \"zigzag_horizontal\"),  # Zigzag circle horizontally\n",
        "    (\"circle moving up-left\", \"circle\", \"up_left\"),  # Move circle up-left\n",
        "    (\"circle moving down-right\", \"circle\", \"down_right\"),  # Move circle down-right\n",
        "    (\"circle moving down-left\", \"circle\", \"down_left\")  # Move circle down-left\n",
        "]\n",
        "\n",
        "# Define a function to create an image with a moving shape\n",
        "def create_image_with_moving_shape(size, frame_num, shape, direction):\n",
        "    # Create a new RGB image with the specified size and white background\n",
        "    img = Image.new('RGB', size, color=(255, 255, 255))\n",
        "    draw = ImageDraw.Draw(img)\n",
        "\n",
        "    # Calculate the initial position of the shape (center of the image)\n",
        "    center_x, center_y = size[0] // 2, size[1] // 2\n",
        "\n",
        "    # Determine the shape position based on the movement direction\n",
        "    if direction == \"down\":\n",
        "        position = (center_x, (center_y + frame_num * 5) % size[1])\n",
        "    elif direction == \"left\":\n",
        "        position = ((center_x - frame_num * 5) % size[0], center_y)\n",
        "    elif direction == \"right\":\n",
        "        position = ((center_x + frame_num * 5) % size[0], center_y)\n",
        "    elif direction == \"diagonal_up_right\":\n",
        "        position = ((center_x + frame_num * 5) % size[0], (center_y - frame_num * 5) % size[1])\n",
        "    elif direction == \"diagonal_down_left\":\n",
        "        position = ((center_x - frame_num * 5) % size[0], (center_y + frame_num * 5) % size[1])\n",
        "    elif direction == \"diagonal_up_left\":\n",
        "        position = ((center_x - frame_num * 5) % size[0], (center_y - frame_num * 5) % size[1])\n",
        "    elif direction == \"diagonal_down_right\":\n",
        "        position = ((center_x + frame_num * 5) % size[0], (center_y + frame_num * 5) % size[1])\n",
        "    elif direction == \"rotate_clockwise\":\n",
        "        img = img.rotate(frame_num * 10, center=(center_x, center_y), fillcolor=(255, 255, 255))\n",
        "        position = (center_x, center_y)\n",
        "    elif direction == \"rotate_counter_clockwise\":\n",
        "        img = img.rotate(-frame_num * 10, center=(center_x, center_y), fillcolor=(255, 255, 255))\n",
        "        position = (center_x, center_y)\n",
        "    elif direction == \"bounce_vertical\":\n",
        "        position = (center_x, center_y - abs(frame_num * 5 % size[1] - center_y))\n",
        "    elif direction == \"bounce_horizontal\":\n",
        "        position = (center_x - abs(frame_num * 5 % size[0] - center_x), center_y)\n",
        "    elif direction == \"zigzag_vertical\":\n",
        "        position = (center_x, center_y - frame_num * 5 % size[1] if frame_num % 2 == 0 else center_y + frame_num * 5 % size[1])\n",
        "    elif direction == \"zigzag_horizontal\":\n",
        "        position = (center_x - frame_num * 5 % size[0] if frame_num % 2 == 0 else center_x + frame_num * 5 % size[0], center_y)\n",
        "    elif direction == \"up_left\":\n",
        "        position = ((center_x - frame_num * 5) % size[0], (center_y - frame_num * 5) % size[1])\n",
        "    elif direction == \"down_right\":\n",
        "        position = ((center_x + frame_num * 5) % size[0], (center_y + frame_num * 5) % size[1])\n",
        "    elif direction == \"down_left\":\n",
        "        position = ((center_x - frame_num * 5) % size[0], (center_y + frame_num * 5) % size[1])\n",
        "    else:\n",
        "        position = (center_x, center_y)\n",
        "\n",
        "    # Draw the shape (circle) at the calculated position\n",
        "    if shape == \"circle\":\n",
        "        draw.ellipse([position[0] - shape_size // 2, position[1] - shape_size // 2, position[0] + shape_size // 2, position[1] + shape_size // 2], fill=(0, 0, 255))\n",
        "\n",
        "    # Return the image as a numpy array\n",
        "    return np.array(img)\n",
        "\n",
        "# Generate the dataset\n",
        "for video_num in range(num_videos):\n",
        "    prompt, shape, direction = random.choice(prompts_and_movements)\n",
        "    video_frames = []\n",
        "    for frame_num in range(frames_per_video):\n",
        "        img_array = create_image_with_moving_shape(img_size, frame_num, shape, direction)\n",
        "        video_frames.append(img_array)\n",
        "\n",
        "    # Save the frames as images in the training dataset directory\n",
        "    video_dir = os.path.join('training_dataset', f'video_{video_num}')\n",
        "    os.makedirs(video_dir, exist_ok=True)\n",
        "    for frame_num, frame in enumerate(video_frames):\n",
        "        frame_image = Image.fromarray(frame)\n",
        "        frame_image.save(os.path.join(video_dir, f'frame_{frame_num}.png'))\n",
        "\n",
        "print(\"Dataset generation complete.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "IF10vxGej8s8"
      },
      "outputs": [],
      "source": [
        "# Iterate over the number of videos to generate\n",
        "for i in range(num_videos):\n",
        "    # Randomly choose a prompt and movement from the predefined list\n",
        "    prompt, shape, direction = random.choice(prompts_and_movements)\n",
        "\n",
        "    # Create a directory for the current video\n",
        "    video_dir = f'training_dataset/video_{i}'\n",
        "    os.makedirs(video_dir, exist_ok=True)\n",
        "\n",
        "    # Write the chosen prompt to a text file in the video directory\n",
        "    with open(f'{video_dir}/prompt.txt', 'w') as f:\n",
        "        f.write(prompt)\n",
        "\n",
        "    # Generate frames for the current video\n",
        "    for frame_num in range(frames_per_video):\n",
        "        # Create an image with a moving shape based on the current frame number, shape, and direction\n",
        "        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)\n",
        "\n",
        "        # Save the generated image as a PNG file in the video directory\n",
        "        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)"
      ]
    },
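    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the modulo arithmetic in `create_image_with_moving_shape` does, here is a minimal sketch that computes only the x-coordinates, without drawing anything. The names `canvas`, `step`, `right_xs`, and `left_xs` are illustrative and do not appear in the code above; the values 64, 5, and 10 mirror `img_size`, the per-frame step, and `frames_per_video`.\n",
        "\n",
        "```python\n",
        "# Illustrative sketch: x-coordinate of the circle centre in each frame\n",
        "canvas = 64           # image width, as in img_size\n",
        "step = 5              # pixels moved per frame\n",
        "center = canvas // 2  # every video starts at the centre\n",
        "\n",
        "# 'right' adds the offset, 'left' subtracts it; % canvas wraps around the edge\n",
        "right_xs = [(center + frame * step) % canvas for frame in range(10)]\n",
        "left_xs = [(center - frame * step) % canvas for frame in range(10)]\n",
        "\n",
        "print(right_xs)  # [32, 37, 42, 47, 52, 57, 62, 3, 8, 13]\n",
        "print(left_xs)   # [32, 27, 22, 17, 12, 7, 2, 61, 56, 51]\n",
        "```\n",
        "\n",
        "Because of the wrap-around, a circle that leaves one edge re-enters from the opposite edge within the same 10-frame video."
      ]
    },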
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HXLQf-Wkj8s9"
      },
      "source": [
        "## Pre-Processing Our Training Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JlSlYNKpj8s9"
      },
      "outputs": [],
      "source": [
        "# Define a dataset class inheriting from torch.utils.data.Dataset\n",
        "class TextToVideoDataset(Dataset):\n",
        "    def __init__(self, root_dir, transform=None):\n",
        "        # Initialize the dataset with root directory and optional transform\n",
        "        self.root_dir = root_dir\n",
        "        self.transform = transform\n",
        "        # List all subdirectories in the root directory\n",
        "        self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]\n",
        "        # Initialize lists to store frame paths and corresponding prompts\n",
        "        self.frame_paths = []\n",
        "        self.prompts = []\n",
        "\n",
        "        # Loop through each video directory\n",
        "        for video_dir in self.video_dirs:\n",
        "            # List all PNG files in the video directory and store their paths\n",
        "            frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]\n",
        "            self.frame_paths.extend(frames)\n",
        "            # Read the prompt text file in the video directory and store its content\n",
        "            with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:\n",
        "                prompt = f.read().strip()\n",
        "            # Repeat the prompt for each frame in the video and store in prompts list\n",
        "            self.prompts.extend([prompt] * len(frames))\n",
        "\n",
        "    # Return the total number of samples in the dataset\n",
        "    def __len__(self):\n",
        "        return len(self.frame_paths)\n",
        "\n",
        "    # Retrieve a sample from the dataset given an index\n",
        "    def __getitem__(self, idx):\n",
        "        # Get the path of the frame corresponding to the given index\n",
        "        frame_path = self.frame_paths[idx]\n",
        "        # Open the image using PIL (Python Imaging Library)\n",
        "        image = Image.open(frame_path)\n",
        "        # Get the prompt corresponding to the given index\n",
        "        prompt = self.prompts[idx]\n",
        "\n",
        "        # Apply transformation if specified\n",
        "        if self.transform:\n",
        "            image = self.transform(image)\n",
        "\n",
        "        # Return the transformed image and the prompt\n",
        "        return image, prompt\n",
        "\n",
        "# Define a set of transformations to be applied to the data\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(), # Convert PIL Image or numpy.ndarray to tensor\n",
        "    transforms.Normalize((0.5,), (0.5,)) # Normalize image with mean and standard deviation\n",
        "])\n",
        "\n",
        "# Load the dataset using the defined transform\n",
        "dataset = TextToVideoDataset(root_dir='training_dataset', transform=transform)\n",
        "# Create a dataloader to iterate over the dataset\n",
        "dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)"
      ]
    },
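    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `ToTensor` and `Normalize` steps above amount to simple per-pixel arithmetic. The sketch below reproduces it in plain Python for a single pixel, assuming a mean and standard deviation of 0.5 as in the transform; `normalize_pixel` and `denormalize_pixel` are hypothetical helper names used only for illustration.\n",
        "\n",
        "```python\n",
        "# ToTensor scales 0..255 to 0..1; Normalize then applies (x - mean) / std\n",
        "def normalize_pixel(value, mean=0.5, std=0.5):\n",
        "    return (value / 255.0 - mean) / std\n",
        "\n",
        "# Inverse mapping, useful when turning model output back into an image\n",
        "def denormalize_pixel(value, mean=0.5, std=0.5):\n",
        "    return round((value * std + mean) * 255.0)\n",
        "\n",
        "white = [normalize_pixel(v) for v in (255, 255, 255)]  # background colour\n",
        "blue = [normalize_pixel(v) for v in (0, 0, 255)]       # circle colour\n",
        "restored = [denormalize_pixel(v) for v in blue]\n",
        "\n",
        "print(white)     # [1.0, 1.0, 1.0]\n",
        "print(blue)      # [-1.0, -1.0, 1.0]\n",
        "print(restored)  # [0, 0, 255]\n",
        "```\n",
        "\n",
        "The resulting range of -1 to 1 is the same range a `Tanh` layer produces, which is why GAN generators commonly end in `Tanh`."
      ]
    },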
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OvaajlG2j8s9"
      },
      "source": [
        "## Implementing GAN Architecture"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QKgbkSNaj8s9"
      },
      "outputs": [],
      "source": [
        "# Define a class for text embedding\n",
        "class TextEmbedding(nn.Module):\n",
        "    # Constructor method with vocab_size and embed_size parameters\n",
        "    def __init__(self, vocab_size, embed_size):\n",
        "        # Call the superclass constructor\n",
        "        super(TextEmbedding, self).__init__()\n",
        "        # Initialize embedding layer\n",
        "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
        "\n",
        "    # Define the forward pass method\n",
        "    def forward(self, x):\n",
        "        # Return embedded representation of input\n",
        "        return self.embedding(x)\n",
        "\n",
        "class Generator(nn.Module):\n",
        "    def __init__(self, text_embed_size):\n",
        "        super(Generator, self).__init__()\n",
        "\n",
        "        # Fully connected layer that takes noise and text embedding as input\n",
        "        self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)\n",
        "\n",
        "        # Transposed convolutional layers to upsample the input\n",
        "        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)\n",
        "        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)\n",
        "        self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)  # Output has 3 channels for RGB images\n",
        "\n",
        "        # Activation functions\n",
        "        self.relu = nn.ReLU(True)  # ReLU activation function\n",
        "        self.tanh = nn.Tanh()       # Tanh activation function for final output\n",
        "\n",
        "    def forward(self, noise, text_embed):\n",
        "        # Concatenate noise and text embedding along the channel dimension\n",
        "        x = torch.cat((noise, text_embed), dim=1)\n",
        "\n",
        "        # Fully connected layer followed by reshaping to 4D tensor\n",
        "        x = self.fc1(x).view(-1, 256, 8, 8)\n",
        "\n",
        "        # Upsampling through transposed convolution layers with ReLU activation\n",
        "        x = self.relu(self.deconv1(x))\n",
        "        x = self.relu(self.deconv2(x))\n",
        "\n",
        "        # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)\n",
        "        x = self.tanh(self.deconv3(x))\n",
        "\n",
        "        return x\n",
        "\n",
        "class Discriminator(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Discriminator, self).__init__()\n",
        "\n",
        "        # Convolutional layers to process input images\n",
        "        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)   # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1\n",
        "        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1\n",
        "        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1\n",
        "\n",
        "        # Fully connected layer for classification\n",
        "        self.fc1 = nn.Linear(256 * 8 * 8, 1)  # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)\n",
        "\n",
        "        # Activation functions\n",
        "        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # Leaky ReLU activation with negative slope 0.2\n",
        "        self.sigmoid = nn.Sigmoid()  # Sigmoid activation for final output (probability)\n",
        "\n",
        "    def forward(self, input):\n",
        "        # Pass input through convolutional layers with LeakyReLU activation\n",
        "        x = self.leaky_relu(self.conv1(input))\n",
        "        x = self.leaky_relu(self.conv2(x))\n",
        "        x = self.leaky_relu(self.conv3(x))\n",
        "\n",
        "        # Flatten the output of convolutional layers\n",
        "        x = x.view(-1, 256 * 8 * 8)\n",
        "\n",
        "        # Pass through fully connected layer with Sigmoid activation for binary classification\n",
        "        x = self.sigmoid(self.fc1(x))\n",
        "\n",
        "        return x"
      ]
    },
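    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The layer sizes in the `Generator` and `Discriminator` follow from the standard output-size formulas for transposed and regular convolutions. The sketch below checks the arithmetic in plain Python for kernel size 4, stride 2, and padding 1; `deconv_out` and `conv_out` are hypothetical helpers, not PyTorch functions, and dilation and output padding are assumed to be at their defaults.\n",
        "\n",
        "```python\n",
        "# Spatial output size of a transposed convolution (upsampling)\n",
        "def deconv_out(size, kernel=4, stride=2, padding=1):\n",
        "    return (size - 1) * stride - 2 * padding + kernel\n",
        "\n",
        "# Spatial output size of a regular convolution (downsampling)\n",
        "def conv_out(size, kernel=4, stride=2, padding=1):\n",
        "    return (size + 2 * padding - kernel) // stride + 1\n",
        "\n",
        "# Generator: starts from the 8x8 map produced by the fully connected layer\n",
        "gen_sizes = [8]\n",
        "for _ in range(3):\n",
        "    gen_sizes.append(deconv_out(gen_sizes[-1]))\n",
        "\n",
        "# Discriminator: starts from a 64x64 input frame\n",
        "disc_sizes = [64]\n",
        "for _ in range(3):\n",
        "    disc_sizes.append(conv_out(disc_sizes[-1]))\n",
        "\n",
        "print(gen_sizes)   # [8, 16, 32, 64]\n",
        "print(disc_sizes)  # [64, 32, 16, 8]\n",
        "```\n",
        "\n",
        "Each layer doubles or halves the resolution, which is why the discriminator’s final linear layer expects `256 * 8 * 8` inputs."
      ]
    },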
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GfCQkirWj8s9"
      },
      "source": [
        "## Coding Training Parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L9Zrz1szj8s9"
      },
      "outputs": [],
      "source": [
        "# Check for GPU\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Create a simple vocabulary for text prompts\n",
        "all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts from prompts_and_movements list\n",
        "vocab = {word: idx for idx, word in enumerate(set(\" \".join(all_prompts).split()))}  # Create a vocabulary dictionary where each unique word is assigned an index\n",
        "vocab_size = len(vocab)  # Size of the vocabulary\n",
        "embed_size = 10  # Size of the text embedding vector\n",
        "\n",
        "def encode_text(prompt):\n",
        "    # Encode a given prompt into a tensor of indices using the vocabulary\n",
        "    return torch.tensor([vocab[word] for word in prompt.split()])\n",
        "\n",
        "# Initialize models, loss function, and optimizers\n",
        "text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # Initialize TextEmbedding model with vocab_size and embed_size\n",
        "netG = Generator(embed_size).to(device)  # Initialize Generator model with embed_size\n",
        "netD = Discriminator().to(device)  # Initialize Discriminator model\n",
        "criterion = nn.BCELoss().to(device)  # Binary Cross Entropy loss function\n",
        "optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Discriminator\n",
        "optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Generator"
      ]
    },
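    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the text encoding concrete, here is a small sketch of the word-to-index step on three of the prompts. It assumes a sorted vocabulary so that the indices are reproducible; `sample_prompts`, `sample_vocab`, and `encode_words` are illustrative names, not part of the notebook’s code.\n",
        "\n",
        "```python\n",
        "sample_prompts = ['circle moving down', 'circle moving left', 'circle bouncing vertically']\n",
        "\n",
        "# One index per unique word; sorting makes the mapping deterministic\n",
        "words = sorted(set(' '.join(sample_prompts).split()))\n",
        "sample_vocab = {word: idx for idx, word in enumerate(words)}\n",
        "\n",
        "def encode_words(prompt):\n",
        "    # Unknown words raise KeyError, so prompts must reuse the training vocabulary\n",
        "    return [sample_vocab[word] for word in prompt.split()]\n",
        "\n",
        "encoded = encode_words('circle moving left')\n",
        "print(sample_vocab)\n",
        "print(encoded)  # [1, 4, 3]\n",
        "```\n",
        "\n",
        "Each index selects one row of the `nn.Embedding` table; averaging those rows yields one vector of length `embed_size` per prompt."
      ]
    },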
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X2tChYiEj8s9"
      },
      "source": [
        "## Training Loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mlSboIs5j8s9"
      },
      "outputs": [],
      "source": [
        "# Number of epochs\n",
        "num_epochs = 13\n",
        "\n",
        "# Iterate over each epoch\n",
        "for epoch in range(num_epochs):\n",
        "    # Iterate over each batch of data\n",
        "    for i, (data, prompts) in enumerate(dataloader):\n",
        "        # Move real data to device\n",
        "        real_data = data.to(device)\n",
        "\n",
        "        # Convert prompts to list\n",
        "        prompts = list(prompts)\n",
        "\n",
        "        # Update Discriminator\n",
        "        netD.zero_grad()  # Zero the gradients of the Discriminator\n",
        "        batch_size = real_data.size(0)  # Get the batch size\n",
        "        labels = torch.ones(batch_size, 1).to(device)  # Create labels for real data (ones)\n",
        "        output = netD(real_data)  # Forward pass real data through Discriminator\n",
        "        lossD_real = criterion(output, labels)  # Calculate loss on real data\n",
        "        lossD_real.backward()  # Backward pass to calculate gradients\n",
        "\n",
        "        # Generate fake data\n",
        "        noise = torch.randn(batch_size, 100).to(device)  # Generate random noise\n",
        "        text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # Encode prompts into text embeddings\n",
        "        fake_data = netG(noise, text_embeds)  # Generate fake data from noise and text embeddings\n",
        "        labels = torch.zeros(batch_size, 1).to(device)  # Create labels for fake data (zeros)\n",
        "        output = netD(fake_data.detach())  # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)\n",
        "        lossD_fake = criterion(output, labels)  # Calculate loss on fake data\n",
        "        lossD_fake.backward()  # Backward pass to calculate gradients\n",
        "        optimizerD.step()  # Update Discriminator parameters\n",
        "\n",
        "        # Update Generator\n",
        "        netG.zero_grad()  # Zero the gradients of the Generator\n",
        "        labels = torch.ones(batch_size, 1).to(device)  # Create labels for fake data (ones) to fool Discriminator\n",
        "        output = netD(fake_data)  # Forward pass the same fake data through the just-updated Discriminator (no detach, so gradients reach the Generator)\n",
        "        lossG = criterion(output, labels)  # Calculate loss for Generator based on Discriminator's response\n",
        "        lossG.backward()  # Backward pass to calculate gradients\n",
        "        optimizerG.step()  # Update Generator parameters\n",
        "\n",
        "    # Print the losses from the last batch of this epoch\n",
        "    print(f\"Epoch [{epoch + 1}/{num_epochs}] Loss D: {(lossD_real + lossD_fake).item():.4f}, Loss G: {lossG.item():.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "veLNs63Xj8s-"
      },
      "source": [
        "## Saving the Trained Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IixqmS-kj8s-"
      },
      "outputs": [],
      "source": [
        "# Save the Generator model's state dictionary to a file named 'generator.pth'\n",
        "torch.save(netG.state_dict(), 'generator.pth')\n",
        "\n",
        "# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'\n",
        "torch.save(netD.state_dict(), 'discriminator.pth')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dNlKX9P0j8s-"
      },
      "source": [
        "## Generating AI Video\n",
        "\n",
        "As discussed earlier, we test the model on a prompt it has not seen during training. This follows the earlier example: if the training data shows dogs fetching balls and cats chasing mice, a good test prompt is a cat fetching a ball or a dog chasing a mouse.\n",
        "In our case, the motion where the circle moves up and then to the right does not appear in the training data, so the model has never seen it, although it has been trained on other motions. We can use this motion as a prompt to test the trained model and observe how it performs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RALjTc8Tj8s-"
      },
      "outputs": [],
      "source": [
        "# Inference function to generate a video based on a given text prompt\n",
        "def generate_video(text_prompt, num_frames=10):\n",
        "    # Create a directory for the generated video frames, named after the text prompt\n",
        "    output_dir = f'generated_video_{text_prompt.replace(\" \", \"_\")}'\n",
        "    os.makedirs(output_dir, exist_ok=True)\n",
        "\n",
        "    # Encode the text prompt into a text embedding tensor\n",
        "    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)\n",
        "\n",
        "    # Generate frames for the video\n",
        "    for frame_num in range(num_frames):\n",
        "        # Sample fresh random noise for each frame\n",
        "        noise = torch.randn(1, 100).to(device)\n",
        "\n",
        "        # Generate a fake frame using the Generator network (no gradients needed at inference)\n",
        "        with torch.no_grad():\n",
        "            fake_frame = netG(noise, text_embed)\n",
        "\n",
        "        # Save the generated frame as an image file\n",
        "        save_image(fake_frame, f'{output_dir}/frame_{frame_num}.png')\n",
        "\n",
        "# Example usage: generate frames for a specific text prompt\n",
        "generate_video('circle moving up-right')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7iymU6CVj8s-"
      },
      "source": [
        "Running the code above creates a directory containing all the frames of our generated video. A little more code is needed to merge these frames into a single short video."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gl9z6qoNj8s-"
      },
      "outputs": [],
      "source": [
        "# Define the path to your folder containing the PNG frames\n",
        "folder_path = 'generated_video_circle_moving_up-right'\n",
        "\n",
        "# Get the list of all PNG files in the folder\n",
        "image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]\n",
        "\n",
        "# Sort the frames by their numeric index (a plain string sort would put frame_10 before frame_2)\n",
        "image_files.sort(key=lambda name: int(os.path.splitext(name)[0].split('_')[-1]))\n",
        "\n",
        "# Create a list to store the frames\n",
        "frames = []\n",
        "\n",
        "# Read each image and append it to the frames list\n",
        "for image_file in image_files:\n",
        "    image_path = os.path.join(folder_path, image_file)\n",
        "    frame = cv2.imread(image_path)\n",
        "    frames.append(frame)\n",
        "\n",
        "# Convert the frames list to a numpy array for easier processing\n",
        "frames = np.array(frames)\n",
        "\n",
        "# Define the frame rate (frames per second)\n",
        "fps = 10\n",
        "\n",
        "# Create a video writer object\n",
        "fourcc = cv2.VideoWriter_fourcc(*'XVID')\n",
        "out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))\n",
        "\n",
        "# Write each frame to the video\n",
        "for frame in frames:\n",
        "    out.write(frame)\n",
        "\n",
        "# Release the video writer\n",
        "out.release()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Sqy1GxSRj8s-"
      },
      "source": [
        "Make sure `folder_path` points to the directory that holds your newly generated frames. After running this code, the frames are merged into `generated_video.avi` and your AI video is ready. Let's see what it looks like."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VRLTU_onj8s-"
      },
      "source": [
        "## What's Missing?\n",
        "I've tested various aspects of this architecture and found that the training data is the key. Including more motions and shapes in the dataset increases variability and improves the model's performance. Since the data is generated through code, producing more varied data won't take much time, so you can focus on refining the generation logic.\n",
        "Furthermore, the GAN architecture discussed in this blog is relatively simple. You can make it more capable by integrating advanced techniques or by using embeddings from a pretrained large language model (LLM) instead of a basic neural network embedding. Tuning hyperparameters such as the embedding size can also significantly affect the model's effectiveness."
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python"
    },
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}